import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os
import argparse
import logging
from scipy import integrate
from scipy.linalg import sqrtm
from scipy import integrate
import copy

def marginal_prob_std(t, sigma):
  """Evaluate sqrt((sigma^(2t) - 1) / (2 * log(sigma))) element-wise over `t`.

  Callers in this file use the result as the noise scale applied at time `t`.
  `t` is moved to CUDA when a CUDA device is available (otherwise to the CPU)
  before the computation, so the result lives on that device.

  Args:
    t: Tensor of time values.
    sigma: Base of the exponential; a plain number (it is passed to `np.log`).

  Returns:
    Tensor with the same shape as `t`.
  """
  target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  times = t.to(target)
  variance = (sigma**(2 * times) - 1.) / 2. / np.log(sigma)
  return torch.sqrt(variance)

def diffusion_coeff(t, sigma):
  """Evaluate sigma**t element-wise over `t`.

  `t` is moved to CUDA when a CUDA device is available (otherwise to the CPU)
  before exponentiation, so the result lives on that device.

  Args:
    t: Tensor of time values.
    sigma: Base of the exponential.

  Returns:
    Tensor with the same shape as `t`.
  """
  if torch.cuda.is_available():
    target = torch.device("cuda")
  else:
    target = torch.device("cpu")
  exponent = t.to(target)
  return sigma**exponent

## The error tolerance for the black-box ODE solver
error_tolerance = 1e-5 #@param {'type': 'number'}
def ode_sampler(score_model,
                marginal_prob_std,
                diffusion_coeff,
                batch_size=64,
                atol=error_tolerance,
                rtol=error_tolerance,
                init_x=0,
                device='cuda',
                z=None,
                sigma = 25.0,
                eps=1e-3):
  """Generate samples by integrating the score-driven ODE with SciPy's RK45 solver.

  Integrates dx/dt = -0.5 * g(t)^2 * score(x, t) from t = 1 down to t = eps,
  where g(t) = diffusion_coeff(t, sigma) and score = score_model.score.

  Args:
    score_model: Object exposing `score(x, t)`; called with a float32 tensor of
      `init_x`'s shape and a float32 tensor holding one time value per row.
    marginal_prob_std: Accepted for interface compatibility; not used here.
    diffusion_coeff: Callable `(t_tensor, sigma) -> tensor` giving g(t).
    batch_size: Accepted for interface compatibility; not used here (the batch
      size is taken from `init_x`).
    atol: Absolute tolerance forwarded to `solve_ivp`.
    rtol: Relative tolerance forwarded to `solve_ivp`.
    init_x: Tensor holding the state at t = 1. Required, despite the
      placeholder default of 0.
    device: Device on which the score model is evaluated and the result is
      returned.
    z: Accepted for interface compatibility; not used here.
    sigma: Second argument passed to `diffusion_coeff`.
    eps: Final integration time.

  Returns:
    Tensor on `device` with the shape of `init_x`, holding the solver state at
    the last time point. Its dtype is whatever `solve_ivp` produced; it is not
    cast back to `init_x`'s dtype.

  Raises:
    TypeError: If `init_x` is not a tensor.
  """
  # The placeholder default (0) has no shape; fail with a clear message instead
  # of an AttributeError further down.
  if not torch.is_tensor(init_x):
    raise TypeError(
        "ode_sampler requires `init_x` to be a torch.Tensor holding the state "
        f"at t = 1, got {type(init_x).__name__}")

  shape = init_x.shape

  def score_eval_wrapper(sample, time_steps):
    """A wrapper of the score-based model for use by the ODE solver."""
    # The solver works on flat float64 NumPy vectors; the model wants batched
    # float32 tensors on `device`.
    sample = torch.tensor(sample, device=device, dtype=torch.float32).reshape(shape)
    time_steps = torch.tensor(time_steps, device=device, dtype=torch.float32).reshape((sample.shape[0], ))
    with torch.no_grad():
      score = score_model.score(sample, time_steps)
    return score.cpu().numpy().reshape((-1,)).astype(np.float64)

  def ode_func(t, x):
    """The ODE function for use by the ODE solver."""
    # Every row of the batch is evaluated at the same solver time `t`.
    time_steps = np.ones((shape[0],)) * t
    g = diffusion_coeff(torch.tensor(t), sigma).cpu().numpy()
    return -0.5 * (g**2) * score_eval_wrapper(x, time_steps)

  # Run the black-box ODE solver on the flattened state.
  flat_init = init_x.detach().reshape(-1).cpu().numpy()
  res = integrate.solve_ivp(ode_func, (1., eps), flat_init, rtol=rtol, atol=atol, method='RK45')
  x = torch.tensor(res.y[:, -1], device=device).reshape(shape)
  return x


class GaussianMixtureDataset(Dataset):
    """In-memory dataset of points drawn from a Gaussian mixture.

    All samples are generated once in the constructor using NumPy's global
    random state and stored in `self.data` (float32) with the index of the
    generating component in `self.labels` (int64).

    Args:
        n_samples: Number of points to draw.
        means: Sequence of component mean vectors.
        covs: Sequence of component covariance matrices, aligned with `means`.
        weights: Mixture probabilities, aligned with `means`; forwarded as `p`
            to `np.random.choice`.
    """

    def __init__(self, n_samples, means, covs, weights):
        self.n_samples = n_samples
        self.means = means
        self.covs = covs
        self.weights = weights
        self.n_components = len(means)
        self.data, self.labels = self._generate_data()

    def _generate_data(self):
        """Draw `n_samples` points and return `(data, labels)` as NumPy arrays.

        Component indices are drawn first, then one multivariate normal draw is
        made per index in order, so results are reproducible under a fixed
        NumPy seed.
        """
        comp_indices = np.random.choice(self.n_components, size=self.n_samples, p=self.weights)
        data = [
            np.random.multivariate_normal(self.means[comp], self.covs[comp])
            for comp in comp_indices
        ]
        return np.array(data, dtype=np.float32), np.array(comp_indices, dtype=np.int64)

    def __len__(self):
        """Return the number of samples requested at construction."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return the `(sample, component_label)` pair stored at `idx`."""
        return self.data[idx], self.labels[idx]

    
class GaussianFourierProjection(nn.Module):
  """Random Fourier feature encoder for a batch of scalar time steps.

  Holds `embed_dim // 2` frequencies drawn from a normal distribution and
  multiplied by `scale`; they are registered as a parameter with gradients
  disabled. The output concatenates the sine and cosine of
  `2 * pi * x * W` along the last axis.
  """
  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    # Frequencies are sampled once here and excluded from gradient updates.
    frequencies = torch.randn(embed_dim // 2) * scale
    self.W = nn.Parameter(frequencies, requires_grad=False)

  def forward(self, x):
    # x is indexed as x[:, None], so each entry is expanded against every
    # frequency in W.
    angles = x[:, None] * self.W[None, :] * 2 * np.pi
    features = (torch.sin(angles), torch.cos(angles))
    return torch.cat(features, dim=-1)
  
class Dense(nn.Module):
  """Thin wrapper around a single `nn.Linear` layer.

  The linear layer's output is returned unchanged; no reshaping is applied.
  """
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.dense = nn.Linear(in_features=input_dim, out_features=output_dim)

  def forward(self, x):
    projected = self.dense(x)
    return projected
    
class ScoreNet_1HiddenLayerFC(nn.Module):
  """Time-conditioned score network with one hidden fully-connected layer.

  The input is projected to `hidden_dim`, a learned projection of the Fourier
  time embedding is added, the activation `x * relu(x)` is applied, and the
  result is projected back to `d` dimensions. The output is divided by the
  constant 16 and then by `marginal_prob_std(t, sigma)`.

  Note: the `marginal_prob_std` constructor argument is stored on the instance,
  but `forward` calls the module-level `marginal_prob_std` function.
  """

  def __init__(self, d, marginal_prob_std, sigma, hidden_dim=16, embed_dim=4):
    super().__init__()
    # Time branch: random Fourier features followed by a linear map to the
    # hidden width.
    time_encoder = GaussianFourierProjection(embed_dim=embed_dim)
    time_projection = nn.Linear(embed_dim, hidden_dim)
    self.embed = nn.Sequential(time_encoder, time_projection)

    # Data branch: d -> hidden_dim -> d.
    self.dense1 = Dense(d, hidden_dim)
    self.dense2 = Dense(hidden_dim, d)

    self.act = lambda v: v * torch.relu(v)
    self.marginal_prob_std = marginal_prob_std
    self.sigma = sigma

  def forward(self, x, t):
    # Embedding of the time steps, already at hidden width.
    time_features = self.embed(t)

    # Hidden representation of the input, shifted by the time features.
    hidden = self.dense1(x)
    hidden += time_features

    out = self.dense2(self.act(hidden))
    out = out / 16

    # Normalize output by the per-sample noise scale at time t.
    noise_scale = marginal_prob_std(t, self.sigma)[:, None]
    return out / noise_scale
    
class Diffusion(nn.Module):
    """Wraps a score network and provides the denoising score-matching loss.

    `n_T` is stored on the instance but is not read by any method of this class.
    """

    def __init__(
        self,
        eps_model: nn.Module,
        n_T: int,
        sigma: float
    ) -> None:
        super().__init__()
        self.eps_model = eps_model
        self.n_T = n_T
        self.sigma = sigma

    def score(self, x_t: torch.Tensor, _ts) -> torch.Tensor:
        """Evaluate the wrapped network at `(x_t, _ts)` and return its output."""
        return self.eps_model(x_t, _ts)

    def forward(self, x: torch.Tensor, eps=1e-5) -> torch.Tensor:
        """Compute the loss on `x` at random times drawn uniformly from [eps, 1).

        Returns a 3-tuple `(loss, perturbed_x, random_t)`: the scalar loss, the
        noised inputs the network was evaluated on, and the sampled times.
        """
        n_rows = x.shape[0]
        uniform = torch.rand(n_rows, device=x.device)
        random_t = uniform * (1. - eps) + eps

        # Perturb each row with Gaussian noise scaled by the std at its time.
        z = torch.randn_like(x)
        std = marginal_prob_std(random_t, self.sigma)
        noise_scale = std[:, None]
        perturbed_x = x + z * noise_scale

        predicted = self.eps_model(perturbed_x, random_t)
        residual = predicted * noise_scale + z
        per_sample = torch.sum(residual**2, dim=(1))
        loss = torch.mean(per_sample)
        return loss, perturbed_x, random_t

    
def sample_means(n_components, d, bound=5.0, max_iters=100000, min_dist=None):
    """Draw well-separated component means by rejection sampling.

    Candidates are drawn uniformly from the cube [-bound, bound]^d using
    NumPy's global random state. A candidate is kept only if its Euclidean
    distance to every mean accepted so far is at least `min_dist`.

    Args:
        n_components: Number of means to return.
        d: Dimension of each mean.
        bound: Half-width of the sampling cube.
        max_iters: Maximum number of candidates to draw.
        min_dist: Minimum pairwise distance between means. Defaults to
            sqrt(d) when None.

    Returns:
        Array of shape (n_components, d).

    Raises:
        ValueError: If fewer than `n_components` means were accepted within
            `max_iters` candidates.
    """
    if min_dist is None:
        min_dist = np.sqrt(d)

    # np.vstack rejects an empty list, so handle the degenerate request here.
    if n_components <= 0:
        return np.empty((0, d))

    means = []
    for _ in range(max_iters):
        if len(means) >= n_components:
            break
        cand = np.random.uniform(-bound, bound, size=d)
        if all(np.linalg.norm(cand - m) >= min_dist for m in means):
            means.append(cand)

    if len(means) < n_components:
        raise ValueError(
            f"Only found {len(means)}/{n_components} means after {max_iters} trials. "
            "Try increasing `bound` or `max_iters`, or reducing `min_dist`."
        )

    return np.vstack(means)

def fit_gaussian(X):
    """Fit a Gaussian to the rows of `X`.

    Rows are treated as observations and columns as variables.

    Returns:
        `(mean, cov)`: the column-wise mean and the covariance computed by
        `np.cov` with `ddof=1`.
    """
    sample_mean = np.mean(X, axis=0)
    sample_cov = np.cov(X, rowvar=False, ddof=1)
    return sample_mean, sample_cov

def frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute the Frechet distance between two Gaussians.

    The value is ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2)).

    Args:
        mu1: Mean of the first Gaussian.
        sigma1: Covariance of the first Gaussian. A 0-d value (as `np.cov`
            returns for one-dimensional data) is treated as a 1x1 matrix.
        mu2: Mean of the second Gaussian.
        sigma2: Covariance of the second Gaussian.
        eps: Diagonal offset added to both covariances when the matrix square
            root of their product contains non-finite entries.

    Returns:
        The distance as a NumPy floating-point scalar.
    """
    mu1 = np.atleast_1d(mu1)
    mu2 = np.atleast_1d(mu2)
    sigma1 = np.atleast_2d(sigma1)
    sigma2 = np.atleast_2d(sigma2)

    diff = mu1 - mu2

    # Called without `disp`: that argument is deprecated in recent SciPy
    # releases, and the error estimate it returned was discarded here anyway.
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular product: retry with a small ridge on both covariances.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # Numerical error can leave a small imaginary component; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2*covmean)
    return fid

def _component_fid(model, ckpt_path, init_x, target_mean, d, sigma, device):
    """Load `ckpt_path` into `model`, sample from `init_x`, and return the Frechet distance to N(target_mean, I)."""
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    # Keyword arguments keep tolerances at their defaults and make the sampler
    # use the same sigma the model was trained with.
    samples = ode_sampler(model,
                          marginal_prob_std,
                          diffusion_coeff,
                          batch_size=init_x.shape[0],
                          init_x=init_x,
                          device=device,
                          sigma=sigma)
    mu_fit, cov_fit = fit_gaussian(samples.cpu().numpy())
    return float(frechet_distance(target_mean, np.identity(d), mu_fit, cov_fit))


def main():
    """Train per-component diffusion models with and without mutual learning and log their FIDs.

    Steps:
      1. Parse CLI arguments, seed NumPy and torch, and set up file logging.
      2. Sample mixture means and draw a labelled Gaussian-mixture dataset.
      3. Train one "vanilla" model per component on that component's samples.
      4. Train one "mutual" model per component; its loss adds `tau` times the
         average time-weighted squared difference between its score and the
         other components' scores. The checkpoint kept is the one from the
         epoch with the smallest worst-case gradient norm across components.
      5. Sample from every checkpoint with the ODE sampler and log the Frechet
         distance to the true component, plus the running worst value.

    Raises:
        ValueError: If a mixture component received no samples.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--d', type=int, default=2)
    parser.add_argument('--n_components', type=int, default=4)
    parser.add_argument('--n_samples', type=int, default=300)
    parser.add_argument('--n_epochs', type=int, default=5000)
    parser.add_argument('--tau', type=float, default=0.001)
    parser.add_argument('--alpha', type=float, default=1.2)
    parser.add_argument('--sigma', type=float, default=25.0)
    parser.add_argument('--lr', type=float, default=1e-4)
    # NOTE: --path is parsed but not used anywhere in this function.
    parser.add_argument('--path', type=str, default='./contents/defalt')
    parser.add_argument('--log_path_name', type=str, default='defalt')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed')
    args = parser.parse_args()

    if args.n_components < 1:
        parser.error("--n_components must be at least 1")
    # Evaluation loads checkpoints that are only written during training.
    if args.n_epochs < 1:
        parser.error("--n_epochs must be at least 1")

    d = args.d
    n_components = args.n_components
    n_samples = args.n_samples
    num_epochs = args.n_epochs
    tau = args.tau
    alpha = args.alpha
    sigma = args.sigma
    learning_rate = args.lr
    log_path_name = args.log_path_name
    log_path = './log/' + log_path_name

    ckpt_dir = "checkpoints/"
    run_ckpt_dir = ckpt_dir + log_path_name
    os.makedirs(run_ckpt_dir, exist_ok=True)
    # logging.basicConfig cannot open the log file if its directory is missing.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # set random seed
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    logging.basicConfig(
        filename=log_path,
        filemode='a',
        level=logging.INFO,
        format='%(asctime)s %(message)s'
    )
    logger = logging.getLogger('train')

    logger.info('-' * 19)
    for k, v in vars(args).items():
        logger.info("%s = %r", k, v)

    bound = 2*np.sqrt(d)
    means = sample_means(n_components, d, bound=bound)
    covs = [np.eye(d) for _ in range(n_components)]
    weights = np.arange(1, n_components+1) ** (-alpha)
    weights /= weights.sum()

    # Create dataset. The constructor already draws the samples, so reuse them
    # instead of generating a second, different draw.
    mixture = GaussianMixtureDataset(n_samples, means, covs, weights)
    samples_np, labels = mixture.data, mixture.labels

    # Each component always trains on its full set of samples, so build the
    # per-component tensors once.
    component_data = []
    for comp in range(n_components):
        comp_samples = samples_np[labels == comp]
        if len(comp_samples) == 0:
            raise ValueError(
                f"Component {comp} received no samples; increase --n_samples "
                "or reduce --alpha / --n_components."
            )
        component_data.append(torch.tensor(comp_samples).to(device))

    dms = {}
    optimizers = {}
    grad_norms = {}

    # Keys 0..n_components-1 are the mutual-learning models.
    for comp in range(n_components):
        score_net = ScoreNet_1HiddenLayerFC(d, marginal_prob_std=marginal_prob_std, sigma=sigma)
        dms[comp] = Diffusion(eps_model=score_net, n_T=100, sigma=sigma).to(device)
        optimizers[comp] = optim.Adam(dms[comp].parameters(), lr=learning_rate)
        grad_norms[comp] = 0
        dms[comp].eval()

    # Keys n_components..2*n_components-1 are the vanilla baselines; they
    # start from copies of the mutual models' initial weights.
    for comp in range(n_components, 2*n_components):
        dms[comp] = copy.deepcopy(dms[comp - n_components])
        optimizers[comp] = optim.Adam(dms[comp].parameters(), lr=learning_rate)
        dms[comp].eval()

    # Vanilla Training
    for epoch in range(num_epochs):
        for comp in range(n_components):
            model = dms[comp + n_components]
            optimizer = optimizers[comp + n_components]

            model.train()
            loss_v, _, _ = model(component_data[comp])

            optimizer.zero_grad()
            loss_v.backward()
            optimizer.step()
            model.eval()

    # Only the final weights are evaluated, so write each checkpoint once
    # after training rather than on every step.
    for comp in range(n_components):
        ckpt_path = os.path.join(run_ckpt_dir, f"model{comp}_vanilla.pth")
        torch.save(dms[comp + n_components].state_dict(), ckpt_path)

    # Mutual Learning
    best_grad_norm = None
    for epoch in range(num_epochs):
        for comp in range(n_components):
            model = dms[comp]
            model.train()

            loss, x_noise, t_noise = model(component_data[comp])

            loss_m = loss
            if n_components > 1:
                # This model's score is the same for every peer comparison.
                own_score = model.score(x_noise, t_noise)
                time_weight = torch.exp(t_noise)
                # Only this model's optimizer steps below, so peers do not
                # need autograd graphs.
                with torch.no_grad():
                    peer_scores = [dms[i].score(x_noise, t_noise)
                                   for i in range(n_components) if i != comp]

                mutual_err_sum = 0
                for peer_score in peer_scores:
                    error_term = own_score - peer_score
                    norm_square = time_weight * error_term.norm(p=2, dim=1)**2
                    mutual_err_sum = mutual_err_sum + torch.mean(norm_square)

                loss_m = loss + tau * mutual_err_sum/(n_components - 1)

            optimizers[comp].zero_grad()
            loss_m.backward()
            optimizers[comp].step()

            # Norm of this model's full gradient for the step just taken.
            grads = [param.grad.view(-1)
                     for param in model.parameters() if param.grad is not None]
            grad_norms[comp] = torch.cat(grads).norm().item()

            model.eval()

        worst_grad_norm = np.array(list(grad_norms.values())).max()

        # Keep the checkpoint from the epoch whose worst-case gradient norm is
        # the smallest seen so far; the first epoch always saves.
        if epoch == 0 or worst_grad_norm <= best_grad_norm:
            best_grad_norm = worst_grad_norm
            for comp in range(n_components):
                ckpt_path = os.path.join(run_ckpt_dir, f"model{comp}.pth")
                torch.save(dms[comp].state_dict(), ckpt_path)

    with torch.no_grad():
        # Floats (not ints) so the running worst values can always be formatted.
        worst_fid_mutual = 0.0
        worst_fid_vanilla = 0.0
        sample_batch_size = 100
        t = torch.ones(sample_batch_size, device=device)
        prior_std = marginal_prob_std(t, sigma)[:, None]

        for i in range(n_components):
            mean_vec = means[i]

            # NOTE(review): the initial noise is seeded from the unseeded
            # `random` module, so it differs between runs even with a fixed
            # --seed. This looks deliberate; confirm.
            with torch.random.fork_rng():
                torch.manual_seed(random.randrange(2**63))
                init_x = torch.randn(sample_batch_size, d, device=device) * prior_std

            # mutual sampling
            net_path_m = os.path.join(run_ckpt_dir, f"model{i}.pth")
            fid_score_m = _component_fid(dms[i], net_path_m, init_x, mean_vec, d, sigma, device)
            if fid_score_m >= worst_fid_mutual:
                worst_fid_mutual = fid_score_m
            logger.info("Epoch %4d | Component %4d | FID_mutual %.4f | Worst Mutual FID %.4f",
                        num_epochs, i, fid_score_m, worst_fid_mutual)

            # vanilla sampling (same initial noise as the mutual model)
            net_path = os.path.join(run_ckpt_dir, f"model{i}_vanilla.pth")
            fid_score = _component_fid(dms[i + n_components], net_path, init_x, mean_vec, d, sigma, device)
            if fid_score >= worst_fid_vanilla:
                worst_fid_vanilla = fid_score
            logger.info("Epoch %4d | Component %4d | FID_vanilla %.4f | Worst Vanilla FID %.4f",
                        num_epochs, i, fid_score, worst_fid_vanilla)

        logger.info('-' * 19)


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()